import os
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
import cartopy.crs as ccrs
from pyproj import Proj, CRS, Transformer
from matplotlib.colors import BoundaryNorm
from descartes import PolygonPatch
from wrf import getvar, get_cartopy, to_np
from netCDF4 import Dataset
from cartopy.io.img_tiles import GoogleTiles
import datetime as dt
import pandas as pd
import geopandas
import cartopy.io.shapereader as shpreader

from plot_marshall_functions import ext2bbox, load_baseimage_any
from plot_marshall_functions import gjson_feature_geometry, gjson2poly
from plot_marshall_functions import refine_dom, read_wrffr
from plot_marshall_functions import pyproj_wrf2wgs, cartopy_defwrfproj, mplext_outer

# --------------------------------------------------------------------
# --------------------------------------------------------------------
# Defaults

thisrun = 'ysu-les-large'
plottype = 'area'
# Number of map panels (wrfout times) to draw. Panel k = nn_idx, ..., 1 shows
# ddhh[init_idx + k] in cell k of a 2x2 grid. The plotting loop only knows the
# VIIRS overpasses for three panels, so keep 1 <= nn_idx <= 3. With 0 the loop
# draws nothing and the final save step fails (no `case`/`fname` defined).
nn_idx = 3
n_idx = nn_idx  # running panel index; the plotting loop counts it down from nn_idx
init_idx = 0  # offset added to the panel index when picking entries of ddhh

geogrid_files = '/glade/scratch/frediani/WRF-Fire/runs/marshall/wps-tim/geo_em.d02.nc'
fpath = os.getcwd()
print("File out path:", fpath)

g = 9.81  # gravitational acceleration [m s-2]; not referenced by the plotting code

# Date/time strings (UTC) of the wrfout files, and their local (MT) equivalents.
# ddhhlocal is not currently used for plotting.

# yapf:disable
ddhh = [
    '30_18:00', '30_19:25', '30_20:15', '30_21:00'
#    '30_19:00', '30_19:45', '30_22:05', '31_02:30'
]

ddhhlocal = [
  '30_11:00',  '30_12:25',  '30_13:15',  '30_14:00'
#  '30_12:00',  '30_12:45',  '30_15:05',  '30_19:30'
]

# --------------------------------------------------------------------
# --------------------------------------------------------------------
# Read WPS Geographic Data

f0 = geogrid_files

# Cartopy includes the datum in the proj4 string - use this for projected plotting
wpwrfcrs = cartopy_defwrfproj(f0)

# Conversion and transformations using PyProj
wrfcrs, mapcrs, wrfext, mapext, latlonext, wrfxy, mapxy, dxy = pyproj_wrf2wgs(wf=f0)
wrfdx = dxy[0]  # grid spacing in x; dxy[1] is used as the y spacing below

# --------------------------------------------------------------------
# --------------------------------------------------------------------
# Read fire perim json file

perim1 = gjson_feature_geometry(filename='marshall.json')

# --------------------------------------------------------------------
# --------------------------------------------------------------------
# Base Image elements

# Outer map extent padded by half a grid cell on every side.
# NOTE(review): assumes mplext_outer returns a matplotlib-style extent
# [left, right, bottom, top] (it is later fed to ext2bbox); x edges are padded
# with dx = dxy[0] and y edges with dy = dxy[1].
dx, dy = dxy[0], dxy[1]
zext = (mplext_outer(mapext) +
        np.array([-dx / 2., +dx / 2., -dy / 2., +dy / 2.])).tolist()

# Convert from lon/lat to projected coordinates in meters
map2meters = Proj(mapcrs, preserve_units=False)

# Polygon Patch (Perimeter / shape file), one polygon per GeoJSON feature
polygon = [gjson2poly(ft=feature, prjtrans=map2meters) for feature in perim1]

# Background: the padded extent moved 3.2 km inward on every side
# (assuming the [left, right, bottom, top] layout noted above).
# NOTE(review): `base` is not drawn by the plotting loop below; this tile
# download could be dropped unless the background image is re-enabled.
baseext = (np.array(zext) + np.array([3200., -3200.0, 3200.0, -3200.0])).tolist()
base = load_baseimage_any('USGSTopo', bbox=ext2bbox(*baseext))


# --------------------------------------------------------------------
# --------------------------------------------------------------------
# Loop-invariant configuration for the plotting loop.

# Plot window in Web-Mercator (EPSG:3857) metres, shared by every panel.
MERC_XMIN, MERC_XMAX = -11717000.139409136, -11701400.535952998
MERC_YMIN, MERC_YMAX = 4854271.4235172225, 4866000.624060913
N_TICKS = 7
DEG = '\N{DEGREE SIGN}'

# Simulation cases. Only the first N_CASES are drawn; a second case would be
# placed in subplot 2*n_idx (see the subplot selection inside the loop).
CASES = ('big2_maria_run9_add_spotting', 'big2_maria_run9')
N_CASES = 1
RUN_ROOT = '/glade/scratch/tjuliano/fire/marshall/joinwrfh/1min/'
WRFOUT_FMT = 'wrfout_d02_2021-12-{}:00'
WRFINPUT_FILE = ('/glade/scratch/tjuliano/fire/sagehen/WRF-Fire-LEAPHI/WRF/test/'
                 'marshall/2021123012_maria/wrfinput_d02')

# Wind-vector styling
skipx = 15
skipy = 8
scale = 70  # bigger makes arrows smaller
barbw = 0.03

# VIIRS active-fire detections.
# VIIRS_BBOX is [lon_min, lon_max, lat_min, lat_max] around the Marshall Fire.
VIIRS_BBOX = [-105.25, -105.10, 39.9, 40.0]
VIIRS_DATE = '2021-12-30'
VIIRS_J1 = '/glade/scratch/tjuliano/fire/marshall/perim/sat/j1_viirs/fire_nrt_J1V-C2_277511.shp'
VIIRS_SUOMI = '/glade/scratch/tjuliano/fire/marshall/perim/sat/suomi_viirs/fire_archive_SV-C2_277512.shp'
# Panel loop index ii -> (shapefile, index into that day's sorted unique ACQ_TIMEs)
VIIRS_PASS = {0: (VIIRS_SUOMI, 1), 1: (VIIRS_J1, 0), 2: (VIIRS_SUOMI, 0)}

# Tick positions (Mercator metres) and their lon/lat labels. In Web-Mercator
# longitude depends only on x and latitude only on y, so every tick is
# converted exactly.
_merc_to_lonlat = Transformer.from_crs(CRS.from_epsg(3857), CRS.from_epsg(4326), always_xy=True)
xticks = np.linspace(MERC_XMIN, MERC_XMAX, N_TICKS)
yticks = np.linspace(MERC_YMAX, MERC_YMIN, N_TICKS)
tick_lons, _ = _merc_to_lonlat.transform(xticks, np.full_like(xticks, MERC_YMIN))
_, tick_lats = _merc_to_lonlat.transform(np.full_like(yticks, MERC_XMIN), yticks)
xticklabels = [f'{abs(lon):0.2f}{DEG} W' for lon in tick_lons]
yticklabels = [f'{abs(lat):0.2f}{DEG} N' for lat in tick_lats]


class StreetMapESRI(GoogleTiles):
    """ESRI World Street Map tiles served from arcgisonline."""

    def _image_url(self, tile):
        x, y, z = tile
        return ('https://server.arcgisonline.com/ArcGIS/rest/services/'
                'World_Street_Map/MapServer/tile/{z}/{y}/{x}.jpg').format(z=z, y=y, x=x)


def viirs_unique_times(viirs_file, date, bbox):
    """Return the sorted unique ACQ_TIME values of detections on `date` inside `bbox`.

    bbox is [lon_min, lon_max, lat_min, lat_max]; it is reordered to the
    [minx, miny, maxx, maxy] order passed to geopandas.read_file.
    """
    bbox2 = [bbox[0], bbox[2], bbox[1], bbox[3]]
    gdf = geopandas.read_file(viirs_file, bbox2)
    return np.unique(gdf[gdf.ACQ_DATE == date].ACQ_TIME).tolist()


def plot_viirs(ax, bbox, viirs_file, a_date, a_time):
    """Scatter VIIRS detections acquired at `a_date`/`a_time` inside `bbox` on `ax`.

    Markers are coloured by the CONFIDENCE attribute (h=green, n=yellow,
    l=red); any other CONFIDENCE value raises KeyError.
    """
    bbox2 = [bbox[0], bbox[2], bbox[1], bbox[3]]
    reader = shpreader.Reader(viirs_file, bbox2)
    records = [rec for rec in reader.records()
               if rec.attributes.get('ACQ_DATE') == a_date
               and rec.attributes.get('ACQ_TIME') == a_time]
    conf_colors = {'h': "green", "n": "yellow", 'l': "red"}
    ax.scatter(
        [rec.geometry.x for rec in records],
        [rec.geometry.y for rec in records],
        transform=ccrs.PlateCarree(),
        fc=[conf_colors[rec.attributes.get('CONFIDENCE')] for rec in records],
        s=50, marker='d', ec='gray', zorder=5, label='VIIRS Active Fire'
    )


# Map-rotation factors and mass-grid lat/lon from wrfinput. They are read once
# and the file is closed afterwards.
with Dataset(WRFINPUT_FILE, "r") as wrfin:
    cosalpha = wrfin.variables['COSALPHA'][-1, :, :]
    sinalpha = wrfin.variables['SINALPHA'][-1, :, :]
    lat = getvar(wrfin, "lat", meta=False)
    lon = getvar(wrfin, "lon", meta=False)
coslat = np.cos(to_np(lat) / 180 * np.pi)
sub = (slice(None, None, skipy), slice(None, None, skipx))  # quiver thinning

# Draw to file only (no interactive window); one figure holds every panel.
plt.ioff()
fig = plt.figure(figsize=(9, 7))

if plottype == 'area':
    # Refined (fire-grid) coordinates matching FIRE_AREA; the arguments are
    # the same for every panel, so this is computed once.
    fxlon, fxlat = refine_dom(wrfxy[0], wrfxy[1], 0.25)

# --------------------------------------------------------------------
# --------------------------------------------------------------------
# One panel per wrfout time, filled from the latest time backwards.
for ii in range(nn_idx):
    n_idx = nn_idx - ii  # subplot cell; also selects ddhh[init_idx + n_idx]
    fw1 = WRFOUT_FMT.format(ddhh[init_idx + n_idx])

    for jj, case in enumerate(CASES[:N_CASES]):
        fwrf = RUN_ROOT + case + '/fire'
        fwrf_met = RUN_ROOT + case + '/met'
        print(f'thisrun: {thisrun} plottype: {plottype}\ninit_idx: {init_idx} n_idx: {n_idx}')

        # Figure name. It is recomputed per panel and the save step uses the
        # value from the last panel.
        # NOTE(review): that label therefore reflects n_idx == 1 - confirm intended.
        fname = '_'.join([
            thisrun.replace('/', '_'),
            plottype,
            f'{init_idx:02d}-{n_idx:02d}',
            'dt' + f'{(init_idx/4):02.2f}' + '-' + f'{(init_idx+n_idx)/4:02.2f}',
        ])
        fname = fname.replace('.', '_') + '_3panel_viirs.png'
        print("File out name:", fname)

        # ---- Read WRF data
        if plottype == 'area':
            print("Area:", fw1)
            print("Refined Area:", fw1)
            refarea = read_wrffr(os.path.join(fwrf, fw1), var='FIRE_AREA')[0:-3, 0:-3]
            # coordinates don't exist in the netcdf file for the last 3 points;
            # then remove gridpoints beyond domain edges
            refarea = refarea[0:-4, 0:-4]

        # 10-m winds are needed for every plot type (wind vectors below).
        with Dataset(os.path.join(fwrf_met, fw1)) as nc:
            u10 = getvar(nc, 'U10').values
            v10 = getvar(nc, 'V10').values

        # ---- Axes with street-map tiles in Web-Mercator
        tiles = StreetMapESRI()
        ax = plt.subplot(2, 2, n_idx if jj == 0 else 2 * n_idx, projection=tiles.crs)
        textdict = dict(c='k', va='center', transform=ax.transAxes, fontsize=12)
        # Zoom level 12: a larger value gives more detail but downloads more
        # tiles; larger domains need a smaller value.
        ax.add_image(tiles, 12, interpolation='spline36')
        ax.set_xlim([MERC_XMIN, MERC_XMAX])
        ax.set_ylim([MERC_YMIN, MERC_YMAX])

        # ---- Observed perimeter
        for poly in polygon:
            ax.add_patch(PolygonPatch(poly, fc='None', ec='magenta', zorder=2, lw=3.))
        ax.plot([], [], color='m', label='OBS Perim', zorder=2, linewidth=3)  # legend proxy

        # ---- Modelled perimeter: FIRE_AREA contoured at 0.06
        if plottype == 'area':
            ax.contour(fxlon, fxlat, refarea, [0.06], colors='k', linewidths=3.,
                       transform=wpwrfcrs, zorder=3)
            ax.plot([], [], color='k', zorder=3, label='WRF Perim', linewidth=3)  # legend proxy

        # ---- Wind vectors
        # Rotate grid-relative 10-m winds to earth-relative using COSALPHA/SINALPHA.
        u_earth = u10 * cosalpha - v10 * sinalpha
        v_earth = v10 * cosalpha + u10 * sinalpha
        # The u component is stretched by 1/cos(lat) for plotting in lon/lat
        # (PlateCarree), then each vector is rescaled back to its true speed.
        # NOTE(review): presumably this keeps arrow directions right in degree
        # space - confirm.
        u_src = to_np(u_earth) / coslat
        v_src = to_np(v_earth)
        speed = np.sqrt(to_np(u_earth)**2 + to_np(v_earth)**2)
        speed_src = np.sqrt(u_src**2 + v_src**2)
        q = ax.quiver(to_np(lon[sub]), to_np(lat[sub]),
                      (u_src * speed / speed_src)[sub],
                      (v_src * speed / speed_src)[sub],
                      transform=ccrs.PlateCarree(), linewidth=1., zorder=4, units='inches',
                      scale=scale, width=barbw, facecolor='limegreen', edgecolor='black')
        if n_idx == nn_idx:
            ax.quiverkey(q, 1.45, 0.8, 15, '15 m/s', labelpos='E',
                         fontproperties={'weight': 'bold', 'size': 20}, zorder=4)

        # ---- VIIRS active fire detections for this panel's overpass
        viirs_file, tt = VIIRS_PASS[ii]  # KeyError for panels without a configured overpass
        acq_time = str(viirs_unique_times(viirs_file, VIIRS_DATE, VIIRS_BBOX)[tt])
        print(ii, VIIRS_DATE, acq_time)
        plot_viirs(ax, VIIRS_BBOX, viirs_file, VIIRS_DATE, acq_time)

        # ---- Title and lon/lat tick labels (only on the outer panels)
        ll0 = fw1.replace('wrfout_d02_', '')[:-3]
        ax.set_title(f'{ll0} UTC', **{**textdict, **dict(fontsize=14)})
        if n_idx in (2, 3):
            ax.set_xticks(xticks)
            ax.set_xticklabels(xticklabels)
        if n_idx in (1, 3):
            ax.set_yticks(yticks)
            ax.set_yticklabels(yticklabels)
        ax.tick_params(axis="both", direction='in')
        ax.xaxis.set_tick_params(labelsize=11, rotation=30)
        ax.yaxis.set_tick_params(labelsize=11)

    # One shared figure legend, built from the first panel's labelled artists.
    if ii == 0:
        handles, labels = ax.get_legend_handles_labels()
        fig.legend(handles, labels, loc='center right', bbox_to_anchor=(0.92, 0.225),
                   bbox_transform=fig.transFigure, ncol=1, fontsize=18)

# Save the figure into fpath (the working directory captured at start-up)
# under the case-prefixed name, and report the path actually written.
out_path = os.path.join(fpath, case + '_' + fname)
plt.tight_layout()
plt.savefig(out_path, dpi=600, bbox_inches='tight')
print(out_path)
